{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-05T06:49:23.088133Z",
     "start_time": "2020-10-05T06:49:04.200482Z"
    }
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import math\n",
    "import time\n",
    "import argparse\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.nn.parameter import Parameter\n",
    "from torch.nn.functional import gumbel_softmax\n",
    "from utils import gumbel_softmax_3d\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-05T06:49:25.907419Z",
     "start_time": "2020-10-05T06:49:25.903464Z"
    }
   },
   "outputs": [],
   "source": [
    "def load_data(path):\n",
    "    with open(path, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "    G = nx.from_dict_of_lists(data)\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-05T06:49:30.592914Z",
     "start_time": "2020-10-05T06:49:30.558986Z"
    }
   },
   "outputs": [],
   "source": [
    "def train(args, G):\n",
    "    bs = args.batch_size\n",
    "    n = G.number_of_nodes()\n",
    "    A = nx.to_numpy_matrix(G)\n",
    "    A = torch.from_numpy(A)\n",
    "    A = A.type(torch.float32)\n",
    "    best_log = []\n",
    "    A = A.cuda()\n",
    "\n",
    "    # set parameters\n",
    "    if torch.cuda.is_available():\n",
    "        A = A.cuda()\n",
    "        x = torch.randn(bs, n, 1, device='cuda')*1e-5\n",
    "        x.requires_grad = True\n",
    "    else:\n",
    "        x = torch.randn(bs, n, 1, requires_grad=True)*1e-5\n",
    "\n",
    "    # set optimizer\n",
    "    optimizer = torch.optim.Adam([x], lr=args.lr)\n",
    "\n",
    "    # training\n",
    "    cost_arr = []\n",
    "    for _ in range(args.iterations):\n",
    "        optimizer.zero_grad()\n",
    "        if torch.cuda.is_available():\n",
    "            probs = torch.empty(bs, n, 2, device='cuda')\n",
    "        else:\n",
    "            probs = torch.empty(bs, n, 2)\n",
    "        p = torch.sigmoid(x)\n",
    "        probs[:, :, 0] = p.squeeze()\n",
    "        probs[:, :, -1] = 1-probs[:, :, 0]\n",
    "        logits = torch.log(probs+1e-10)\n",
    "        s = gumbel_softmax_3d(logits, tau=args.tau, hard=args.hard)[:, :, 0]\n",
    "        s = torch.unsqueeze(s, -1)  # size [bs, n, 1]\n",
    "        cost = torch.sum(s)\n",
    "        constraint = torch.sum((1-torch.transpose(s, 1, 2)) @ A @ (1-s))\n",
    "        loss = cost + args.eta * constraint\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            constraint_ = torch.squeeze((1-torch.transpose(s, 1, 2)) @ A @ (1-s))\n",
    "            constraint = constraint_.cpu().numpy()\n",
    "            idx = np.argwhere(constraint == 0)  # select constraint=0\n",
    "\n",
    "            \n",
    "            cost_ = torch.sum(s, dim=1).squeeze()\n",
    "\n",
    "            loss_ = cost_ + constraint_\n",
    "            \n",
    "            if len(idx) != 0:\n",
    "\n",
    "                s = gumbel_softmax_3d(logits, tau=args.tau, hard=True)[:, :, 0]\n",
    "                s = torch.unsqueeze(s, -1).cpu()  # size [bs, n, 1]\n",
    "                cost_ = torch.sum(s, dim=1).squeeze()\n",
    "                cost = torch.sum(s, dim=1)[idx.reshape(-1,)]\n",
    "                # from size [bs, 1] select constrain=0\n",
    "                \n",
    "                if _ > 5000 and torch.min(cost) < min(cost_arr):\n",
    "                    k = np.argwhere(torch.min(cost.cpu()))\n",
    "                    \n",
    "                    print('##################################################')\n",
    "                    print('####################',_,'####################')\n",
    "                    print(idx.squeeze())\n",
    "                    print(cost_[idx.squeeze()])\n",
    "                    print('##################################################')\n",
    "\n",
    "                \n",
    "                cost_arr.append(torch.min(cost.cpu()))\n",
    "                \n",
    "            #交叉变异\n",
    "            if _ % 10000 == 0 and _ >= 10000:\n",
    "                \n",
    "                ratio = 1/4\n",
    "                r = int(ratio * bs)\n",
    "                m = 2e-3\n",
    "                \n",
    "                maxdx = torch.argsort(loss_, dim=0,descending=True).squeeze()\n",
    "                mindx = torch.argsort(loss_, dim=0,descending=False).squeeze()\n",
    "\n",
    "                #print('maxdx:',maxdx)\n",
    "                #print('mindx:',mindx)\n",
    "                L1 = random.sample(list(maxdx[0:r]), r)\n",
    "                L2 = random.sample(list(mindx[0:r]), r)\n",
    "                print('L1:',L1)\n",
    "                #print('L2:',L2)\n",
    "                #print(loss_[L1[0]])\n",
    "                for i in range(r//2):\n",
    "                    #print(L1[i])\n",
    "                    #print('before:',nnn.pps.data[L1[i], : ,0])\n",
    "                    #print(L2[i])\n",
    "                    #print('father:',nnn.pps.data[L2[i], :, 0])\n",
    "                    #print(L2[-(i+1)])\n",
    "                    #print('mother:',nnn.pps.data[L2[-(i+1)], :, 0])\n",
    "                    rand = torch.rand(n).cuda()\n",
    "                    x.data[L1[i], : ,0] = torch.where(rand < 0.5, \n",
    "                                                            x.data[L2[i], :, 0], \n",
    "                                                            x.data[L2[r-1-i], :, 0])\n",
    "                    x.data[L1[r-1-i], :, 0] = torch.where(rand < 0.5, \n",
    "                                                            x.data[L2[r-1-i], :, 0], \n",
    "                                                            x.data[L2[i], :, 0])\n",
    "                    #print('after:',nnn.pps.data[L1[i], : ,0])\n",
    "                    rand = torch.rand(n).cuda()\n",
    "                    x.data[L1[i], : ,0] = torch.where(rand < m, \n",
    "                                                            x.data[L1[i], : ,0], \n",
    "                                                            torch.randn_like(x.data[L1[i], : ,0]) * 1e-5)\n",
    "                    #count = torch.where(rand<m,torch.ones_like(rand),torch.zeros_like(rand))\n",
    "                    #print('mutation_count:',torch.sum(count))\n",
    "                \n",
    "            if _ % 1000 == 0:\n",
    "                print(_)\n",
    "                if len(cost_arr) != 0:\n",
    "\n",
    "                    print('constraint:',torch.sort(constraint_))\n",
    "                    print('loss:',torch.sort(loss_))\n",
    "                    print('# {}, cost: {}'.format(_, ((np.sort(cost_arr))[0:8])))\n",
    "                else:\n",
    "                    print('Failed!')\n",
    "\n",
    "    return cost_arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-05T06:49:47.614354Z",
     "start_time": "2020-10-05T06:49:36.399387Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-4-df5a154d9fec>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     39\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     40\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'__main__'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 41\u001b[1;33m     \u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-4-df5a154d9fec>\u001b[0m in \u001b[0;36mmain\u001b[1;34m()\u001b[0m\n\u001b[0;32m     31\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     32\u001b[0m     \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mensemble\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 33\u001b[1;33m         \u001b[0mcost\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mG\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     34\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcost\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     35\u001b[0m             \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'# {}, cost: {}'\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcost\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-3-c03b46bf5a5c>\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(args, G)\u001b[0m\n\u001b[0;32m      6\u001b[0m     \u001b[0mA\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mA\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      7\u001b[0m     \u001b[0mbest_log\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m     \u001b[0mA\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mA\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      9\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m     \u001b[1;31m# set parameters\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "def main():\n",
    "    # Training settings\n",
    "    parser = argparse.ArgumentParser(\n",
    "        description='Solving MIS problems (with fixed tau in GS, parallel version)')\n",
    "    parser.add_argument('--batch-size', type=int, default=128,\n",
    "                        help='batch size (default: 128)')\n",
    "    parser.add_argument('--data', type=str, default='citeseer',\n",
    "                        help='data name (default: cora)')\n",
    "    parser.add_argument('--tau', type=float, default=1.,\n",
    "                        help='tau value in Gumbel-softmax (default: 1)')\n",
    "    parser.add_argument('--hard', type=bool, default=True,\n",
    "                        help='hard sampling in Gumbel-softmax (default: True)')\n",
    "    parser.add_argument('--lr', type=float, default=1e-2,\n",
    "                        help='learning rate (default: 1e-2)')\n",
    "    parser.add_argument('--eta', type=float, default=3.,\n",
    "                        help='constraint (default: 5)')\n",
    "    parser.add_argument('--ensemble', type=int, default=10,\n",
    "                        help='# experiments (default: 100)')\n",
    "    parser.add_argument('--iterations', type=int, default=1000000,\n",
    "                        help='# iterations in gradient descent (default: 20000)')\n",
    "    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')\n",
    "    args = parser.parse_args(args=[])\n",
    "\n",
    "    # torch.manual_seed(args.seed)\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "    print(device)\n",
    "\n",
    "    # loading data\n",
    "    G = load_data('./data/ind.' + args.data + '.graph')\n",
    "\n",
    "    for i in range(args.ensemble):\n",
    "        cost = train(args, G)\n",
    "        if len(cost) != 0:\n",
    "            print('# {}, cost: {}'.format(i, min(cost)))\n",
    "        else:\n",
    "            print('Failed!')\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
